{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solution via pseudo inverse: tensor([ 1.0766,  0.8976, -0.9582])\n",
      "Loss at step 0: 16.155038833618164\n",
      "Loss at step 100: 0.238687664270401\n",
      "Loss at step 200: 0.217646986246109\n",
      "Loss at step 300: 0.21724411845207214\n",
      "Loss at step 400: 0.2172362506389618\n",
      "Loss at step 500: 0.21723619103431702\n",
      "Loss at step 600: 0.21723611652851105\n",
      "Loss at step 700: 0.21723619103431702\n",
      "Loss at step 800: 0.21723619103431702\n",
      "Loss at step 900: 0.21723619103431702\n",
      "The solution via gradient descent is tensor([ 1.0766,  0.8976, -0.9582])\n"
     ]
    }
   ],
   "source": [
    "# Design matrix: 15 samples x 2 features, augmented below with a column of\n",
    "# ones so that the last unknown acts as the intercept (bias) term.\n",
    "X = torch.tensor([[0.11, 0.09], [0.01, 0.02], [0.98, 0.91],\n",
    "              [0.12, 0.21], [0.98, 0.99], [0.85, 0.87],\n",
    "              [0.03, 0.14], [0.55, 0.45], [0.49, 0.51],\n",
    "              [0.99, 0.01], [0.02, 0.89], [0.31, 0.47],\n",
    "              [0.55, 0.29], [0.87, 0.76], [0.63, 0.24]], dtype=torch.float)\n",
    "# Take the row count from X instead of hard-coding it, so adding or removing\n",
    "# samples cannot desynchronise the bias column.\n",
    "X = torch.column_stack((X, torch.ones(X.shape[0], dtype=X.dtype)))\n",
    "y = torch.tensor([-0.8, -0.97, 0.89, -0.67, 0.97, 0.72,\n",
    "              -0.83, 0.00, 0.00, 0.00, -0.09, -0.22,\n",
    "              -0.16, 0.63, 0.37], dtype=torch.float)\n",
    "assert X.shape[0] == y.shape[0], \"X and y must have the same number of samples\"\n",
    "\n",
    "# Closed-form least-squares solution w = pinv(X) @ y. When X has full column\n",
    "# rank this equals the normal-equation formula (X^T X)^-1 X^T y, but pinv is\n",
    "# SVD-based and avoids explicitly inverting X^T X, whose condition number is\n",
    "# the square of X's and therefore amplifies floating-point error.\n",
    "solution_pseudo = torch.linalg.pinv(X) @ y\n",
    "print(\"Solution via pseudo inverse: {}\".format(solution_pseudo))\n",
    "\n",
    "\n",
    "# The model below predicts torch.mm(X, w) with w of shape (num_features, 1),\n",
    "# so the targets are reshaped to a matching column vector for MSELoss.\n",
    "y = y.reshape((-1, 1))\n",
    "\n",
    "# Linear model whose single parameter is the weight vector being solved for.\n",
    "class LinearModel(torch.nn.Module):\n",
    "    \"\"\"Linear model without a separate bias: predicts X @ w.\n",
    "\n",
    "    An intercept can be learned by appending a column of ones to X.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features):\n",
    "        super().__init__()\n",
    "        # Weight column vector of shape (num_features, 1), initialised from a\n",
    "        # standard normal distribution.\n",
    "        self.w = torch.nn.Parameter(torch.randn(num_features, 1))\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Return the predictions X @ w.\n",
    "\n",
    "        Args:\n",
    "            X: 2-D tensor of shape (n_samples, num_features).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (n_samples, 1).\n",
    "        \"\"\"\n",
    "        return torch.mm(X, self.w)\n",
    "\n",
    "# Fix the RNG so the random weight initialisation (and hence the printed loss\n",
    "# trajectory) is reproducible under Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# One unknown per column of X (the features plus the column of ones).\n",
    "num_unknowns = X.shape[1]\n",
    "model = LinearModel(num_features=num_unknowns)\n",
    "# PyTorch MSE loss with reduction='sum': the sum (not the mean) of squared\n",
    "# errors, so the gradient magnitude grows with the number of samples.\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)\n",
    "\n",
    "# Train the model iteratively with full-batch gradient descent.\n",
    "num_steps = 1000\n",
    "for step in range(num_steps):\n",
    "    y_pred = model(X)\n",
    "    loss = loss_fn(y_pred, y)\n",
    "    if step % 100 == 0:\n",
    "        print(\"Loss at step {}: {}\".format(step, loss.item()))\n",
    "\n",
    "    # Zero the gradients before running the backward pass.\n",
    "    optimizer.zero_grad()\n",
    "    # Compute the gradients for this step.\n",
    "    loss.backward()\n",
    "    # Gradient descent update: w <- w - lr * grad.\n",
    "    optimizer.step()\n",
    "\n",
    "# detach() returns the learned weights without autograd history; it is the\n",
    "# supported replacement for the legacy .data attribute.\n",
    "solution_gd = torch.squeeze(model.w.detach())\n",
    "print(\"The solution via gradient descent is {}\".format(solution_gd))\n",
    "\n",
    "# float32 gradient descent only agrees with the closed form up to small\n",
    "# rounding noise (the loss above still jitters in its last digits), so the\n",
    "# default allclose tolerances (rtol=1e-5, atol=1e-8) are fragile here.\n",
    "assert torch.allclose(solution_pseudo, solution_gd, atol=1e-4), (\n",
    "    \"Gradient descent did not converge to the pseudo-inverse solution: \"\n",
    "    \"{} vs {}\".format(solution_gd, solution_pseudo))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
